# -*- coding: utf-8 -*-
"""Location: ./plugins/vault/vault_plugin.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

Vault Plugin.

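Injects authentication headers into tool calls using tokens supplied by the caller
in a vault header (``X-Vault-Tokens`` by default). The target system is identified
either from gateway tags or from the host of the gateway's OAuth2 ``token_url``.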

Hook: tool_pre_invoke
"""

# Standard
from enum import Enum
import json
from urllib.parse import urlparse

# Third-Party
from pydantic import BaseModel

# First-Party
from mcpgateway.db import get_db
from mcpgateway.plugins.framework import (
    HttpHeaderPayload,
    Plugin,
    PluginConfig,
    PluginContext,
    ToolPreInvokePayload,
    ToolPreInvokeResult,
)
from mcpgateway.services.gateway_service import GatewayService
from mcpgateway.services.logging_service import LoggingService

# Initialize logging service first.
# Module-level logger obtained from the gateway's LoggingService; every function and
# method in this module logs through it.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)


class VaultHandling(Enum):
    """How a token taken from the vault header is turned into a credential.

    Members:
        RAW: The vault value is used verbatim as the bearer token.
    """

    RAW = "raw"  # pass the stored token through unchanged


class SystemHandling(Enum):
    """Strategy used to work out which external system a gateway talks to.

    Members:
        TAG: Read the system from a prefixed gateway tag (e.g. ``system:github.com``).
        OAUTH2_CONFIG: Use the host of the gateway's OAuth2 ``token_url``.
    """

    TAG = "tag"  # look at gateway tags
    OAUTH2_CONFIG = "oauth2_config"  # look at the gateway's OAuth2 settings


class VaultConfig(BaseModel):
    """Typed view of the vault plugin's ``config`` mapping.

    Attributes:
        system_tag_prefix: Tag prefix naming the target system (``<prefix>:<system>``).
        vault_header_name: Request header whose value is the JSON map of vault tokens.
        vault_handling: How a located token is converted into a credential.
        system_handling: Where the target system is read from (tags or OAuth2 config).
        auth_header_tag_prefix: Tag prefix naming a custom header for PAT tokens
            (``<prefix>:<Header-Name>``, e.g. ``AUTH_HEADER:X-GitHub-Token``).
    """

    # Field order is kept as-is: it determines dump/schema ordering.
    system_tag_prefix: str = "system"
    vault_header_name: str = "X-Vault-Tokens"
    vault_handling: VaultHandling = VaultHandling.RAW
    system_handling: SystemHandling = SystemHandling.TAG
    auth_header_tag_prefix: str = "AUTH_HEADER"


class Vault(Plugin):
    """Inject credentials from a caller-supplied token vault into tool calls.

    The caller sends a JSON object in the configured vault header (``X-Vault-Tokens``
    by default). The object maps token keys of the form
    ``system[:scope][:token_type][:token_name]`` to token values. Before a tool is
    invoked, the plugin works out which *system* the target gateway belongs to, either
    from gateway tags or from the host of the gateway's OAuth2 ``token_url``. It then
    selects the matching token and writes it into the outgoing headers.
    """

    def __init__(self, config: PluginConfig):
        """Initialize the vault plugin.

        Args:
            config: Plugin configuration. Its ``config`` mapping is validated against
                ``VaultConfig``. Invalid settings fall back to the defaults, with a
                warning, so the plugin still loads.
        """
        super().__init__(config)
        # load config with pydantic model for convenience
        try:
            self._sconfig = VaultConfig.model_validate(self._config.config or {})
        except Exception as exc:
            # Keep the best-effort fallback, but make the misconfiguration visible.
            logger.warning(f"Invalid vault plugin config, falling back to defaults: {exc}")
            self._sconfig = VaultConfig()

    def _parse_vault_token_key(self, key: str) -> tuple[str, str | None, str | None, str | None]:
        """Parse vault token key in format: system[:scope][:token_type][:token_name].

        Args:
            key: Token key to parse (e.g., "github.com:USER:OAUTH2:TOKEN" or "github.com").

        Returns:
            Tuple of (system, scope, token_type, token_name). Missing parts are None.
        """
        parts = key.split(":")
        # str.split always returns at least one element, so parts[0] always exists.
        system = parts[0]
        scope = parts[1] if len(parts) > 1 else None
        token_type = parts[2] if len(parts) > 2 else None
        token_name = parts[3] if len(parts) > 3 else None
        return system, scope, token_type, token_name

    @staticmethod
    def _find_header_key(headers: dict[str, str], name: str) -> str | None:
        """Return the key of ``headers`` matching ``name``, or None.

        HTTP header names are case-insensitive. An exact match is preferred; otherwise
        the first key that matches ignoring case is returned.
        """
        if name in headers:
            return name
        wanted = name.lower()
        return next((key for key in headers if key.lower() == wanted), None)

    @staticmethod
    def _set_header(headers: dict[str, str], name: str, value: str) -> None:
        """Set header ``name`` to ``value``, dropping keys that differ only in case.

        Without this, an existing ``authorization`` entry would be sent alongside the
        newly added ``Authorization`` one.
        """
        wanted = name.lower()
        for key in [k for k in headers if k != name and k.lower() == wanted]:
            del headers[key]
        headers[name] = value

    @staticmethod
    def _normalize_tags(tags: list | None) -> list[str]:
        """Return the non-empty string values of gateway tags.

        Tags may be plain strings, dicts of the form ``{"id": ..., "label": ...}``, or
        objects exposing a ``label`` attribute. For the latter two, the label is used.
        """
        normalized: list[str] = []
        for tag in tags or []:
            if isinstance(tag, str):
                value = tag
            elif isinstance(tag, dict):
                # Use 'label' field (the actual tag value)
                value = str(tag.get("label", ""))
            elif hasattr(tag, "label"):
                value = str(getattr(tag, "label"))
            else:
                continue
            if value:
                normalized.append(value)
        return normalized

    @staticmethod
    def _find_prefixed_value(values: list[str], prefix: str) -> str | None:
        """Return the text after ``prefix`` in the first value that starts with it.

        Slicing (rather than ``split(prefix)``) keeps values that themselves contain
        the prefix intact.
        """
        for value in values:
            if value.startswith(prefix):
                return value[len(prefix):]
        return None

    def _resolve_from_tags(self, gateway_metadata: object) -> tuple[str | None, str | None]:
        """Return ``(system_key, auth_header)`` read from the gateway's tags."""
        tags = self._normalize_tags(getattr(gateway_metadata, "tags", None))
        # Find system tag with the configured prefix (e.g. "system:github.com")
        system_key = self._find_prefixed_value(tags, self._sconfig.system_tag_prefix + ":")
        if system_key:
            logger.info(f"Using vault system from GW tags: {system_key}")
        # Find auth header tag with the configured prefix (e.g., "AUTH_HEADER:X-GitHub-Token")
        auth_header = self._find_prefixed_value(tags, self._sconfig.auth_header_tag_prefix + ":")
        if auth_header:
            logger.info(f"Found AUTH_HEADER tag: {auth_header}")
        return system_key, auth_header

    async def _resolve_from_oauth_config(self, context: PluginContext) -> str | None:
        """Return the host of the gateway's OAuth2 ``token_url`` as the system key."""
        # NOTE(review): the server id from the global context is used as the gateway
        # id (unchanged from the previous implementation). Confirm this identifies
        # the gateway being invoked.
        gw_id = context.global_context.server_id
        if not gw_id:
            return None
        gen = get_db()
        db = next(gen)
        try:
            gateway = await GatewayService().get_gateway(db, gw_id)
            oauth_config = gateway.oauth_config
            # Never log oauth_config itself: it may contain client secrets.
            if not oauth_config or "token_url" not in oauth_config:
                logger.info(f"Gateway {gw_id} has no OAuth2 token_url configured")
                return None
            system_key = urlparse(oauth_config["token_url"]).hostname
            logger.info(f"Using vault system from oauth_config: {system_key}")
            return system_key
        finally:
            # Closing the generator lets get_db run its own cleanup code.
            gen.close()

    async def _resolve_system(self, context: PluginContext) -> tuple[str | None, str | None]:
        """Return ``(system_key, auth_header)`` according to ``system_handling``."""
        handling = self._sconfig.system_handling
        if handling == SystemHandling.TAG:
            # .get(): a missing gateway entry means "no system", not a KeyError.
            gateway_metadata = (context.global_context.metadata or {}).get("gateway")
            return self._resolve_from_tags(gateway_metadata)
        if handling == SystemHandling.OAUTH2_CONFIG:
            return await self._resolve_from_oauth_config(context), None
        return None, None

    @staticmethod
    def _load_vault_tokens(raw: str) -> dict | None:
        """Decode the vault header; return None (after logging) unless it is a JSON object."""
        try:
            tokens = json.loads(raw)
        except (json.JSONDecodeError, TypeError) as e:
            logger.error(f"Failed to parse vault tokens from header: {e}")
            return None
        if not isinstance(tokens, dict):
            logger.error("Failed to parse vault tokens from header: expected a JSON object")
            return None
        return tokens

    @staticmethod
    def _token_to_str(value: object) -> str | None:
        """Stringify a vault token value; JSON ``null`` means "no token"."""
        return None if value is None else str(value)

    def _find_token(self, vault_tokens: dict, system_key: str) -> tuple[str | None, str | None]:
        """Return ``(token_key, token_value)`` for ``system_key``, or ``(None, None)``.

        An exact key match wins. Otherwise the first complex key whose system part
        equals ``system_key`` is used.
        """
        if system_key in vault_tokens:
            logger.info(f"Found exact match for system key: {system_key}")
            return str(system_key), self._token_to_str(vault_tokens[system_key])
        for key, value in vault_tokens.items():
            parsed_system, scope, token_type, token_name = self._parse_vault_token_key(key)
            if parsed_system == system_key:
                logger.info(f"Found matching token with complex key: {key} (system: {parsed_system}, scope: {scope}, type: {token_type}, name: {token_name})")
                return key, self._token_to_str(value)
        return None, None

    def _apply_token(self, headers: dict[str, str], token_key: str, token_value: str, auth_header: str | None) -> bool:
        """Write the token into ``headers`` based on its type; return True if changed."""
        parsed_system, _scope, token_type, _name = self._parse_vault_token_key(token_key)
        if token_type == "PAT":
            logger.info(f"Processing PAT token for system: {parsed_system}")
            if auth_header:
                logger.info(f"Using AUTH_HEADER tag for {parsed_system}: header={auth_header}")
                self._set_header(headers, auth_header, token_value)
            else:
                logger.info(f"No AUTH_HEADER tag found for {parsed_system}, using Bearer token")
                self._set_header(headers, "Authorization", f"Bearer {token_value}")
            return True

        # OAUTH2, a missing type, and unknown types all fall back to a bearer token.
        known_type = token_type in (None, "OAUTH2")
        if not known_type:
            logger.warning(f"Unknown token type '{token_type}', using default Bearer token")
        if self._sconfig.vault_handling != VaultHandling.RAW:
            return False
        if known_type:
            logger.info(f"Set Bearer token for system: {parsed_system}")
        self._set_header(headers, "Authorization", f"Bearer {token_value}")
        return True

    async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult:
        """Generate bearer tokens from vault-saved tokens before tool invocation.

        Args:
            payload: The tool payload containing arguments.
            context: Plugin execution context.

        Returns:
            Result with potentially modified headers containing bearer token.
        """
        # Log only the tool name: the payload headers carry the vault tokens themselves.
        logger.debug(f"Processing tool pre-invoke for tool {payload.name}")

        headers: dict[str, str] = payload.headers.model_dump() if payload.headers else {}

        # Check for the vault header first, so requests without one never touch the DB.
        vault_key = self._find_header_key(headers, self._sconfig.vault_header_name)
        if vault_key is None:
            logger.debug(f"Vault header '{self._sconfig.vault_header_name}' not found in headers")
            return ToolPreInvokeResult()

        system_key, auth_header = await self._resolve_system(context)
        if not system_key:
            logger.warning("System cannot be determined from gateway metadata.")
            return ToolPreInvokeResult()

        vault_tokens = self._load_vault_tokens(headers[vault_key])
        if vault_tokens is None:
            return ToolPreInvokeResult()

        token_key, token_value = self._find_token(vault_tokens, system_key)
        # NOTE(review): when no token matches, the vault header is left in place, as
        # before. Confirm it is not forwarded to the tool server.
        if not (token_key and token_value):
            return ToolPreInvokeResult()

        if not self._apply_token(headers, token_key, token_value, auth_header):
            return ToolPreInvokeResult()

        # Remove vault header after processing; it holds every token the caller has.
        headers.pop(vault_key, None)
        payload.headers = HttpHeaderPayload(root=headers)
        logger.info(f"Modified tool '{payload.name}' to add auth header")
        return ToolPreInvokeResult(modified_payload=payload)

    async def shutdown(self) -> None:
        """Shutdown the plugin gracefully.

        Returns:
            None.
        """
        return None
